{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "97086e6f",
   "metadata": {},
   "source": [
    "# Lesson 4: Azure OpenAI Function Calling Feature"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe2677c5",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b9a66f8",
   "metadata": {},
   "source": [
    "**Note**: The pre-configured cloud resource grants you access to the Azure OpenAI GPT model. The key and endpoint provided below are intended for teaching purposes only. Your notebook environment is already set up with the necessary keys, which may differ from those used by the instructor during the filming."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f871d539-bb14-4f6b-9212-1222a1227a35",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from openai import AzureOpenAI\n",
    "import json\n",
    "\n",
    "client = AzureOpenAI(\n",
    "  azure_endpoint = os.getenv(\"AZURE_OPENAI_ENDPOINT\"),\n",
    "  api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
    "  api_version=\"2023-05-15\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4673778b",
   "metadata": {},
   "source": [
    "## 1. Using an illustrative example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7f93802-cb53-42cd-b79b-a8b144f73efb",
   "metadata": {
    "height": 370
   },
   "outputs": [],
   "source": [
    "def get_current_weather(location, unit=\"fahrenheit\"):\n",
    "    \"\"\"Get the current weather in a given location. \n",
    "    The default unit when not specified is fahrenheit\"\"\"\n",
    "    if \"new york\" in location.lower():\n",
    "        return json.dumps(\n",
    "            {\"location\": \"New York\", \"temperature\": \"40\", \"unit\": unit}\n",
    "        )\n",
    "    elif \"san francisco\" in location.lower():\n",
    "        return json.dumps(\n",
    "            {\"location\": \"San Francisco\", \"temperature\": \"50\", \"unit\": unit}\n",
    "        )\n",
    "    elif \"las vegas\" in location.lower():\n",
    "        return json.dumps(\n",
    "            {\"location\": \"Las Vegas\", \"temperature\": \"70\", \"unit\": unit}\n",
    "        )\n",
    "    else:\n",
    "        return json.dumps(\n",
    "            {\"location\": location, \"temperature\": \"unknown\"}\n",
    "        )\n",
    "\n",
    "get_current_weather(\"New York\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4aa6f455",
   "metadata": {},
   "source": [
    "### Define the tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f37fd6aa-ce79-4f01-bb96-ccb1b5425743",
   "metadata": {
    "height": 676
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"user\",\n",
    "     \"content\": \"\"\"What's the weather like in San Francisco,\n",
    "                   New York, and Las Vegass?\"\"\"\n",
    "    }\n",
    "]\n",
    "\n",
    "tools = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_current_weather\",\n",
    "            \"description\": \"\"\"Get the current weather in a given\n",
    "                              location.The default unit when not\n",
    "                              specified is fahrenheit\"\"\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"location\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"\"\"The city and state,\n",
    "                                        e.g. San Francisco, CA\"\"\",\n",
    "                    },\n",
    "                    \"unit\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"default\":\"fahrenheit\",\n",
    "                        \"enum\": [ \"fahrenheit\", \"celsius\"],\n",
    "                        \"description\": \"\"\"The messuring unit for\n",
    "                                          the temperature.\n",
    "                                          If not explicitly specified\n",
    "                                          the default unit is \n",
    "                                          fahrenheit\"\"\"\n",
    "                    },\n",
    "                },\n",
    "                \"required\": [\"location\"],\n",
    "            },\n",
    "        },\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04d25317",
   "metadata": {},
   "source": [
    "### Use the function calling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c4b13fc-aec6-4753-9cb7-d93243ba4a4d",
   "metadata": {
    "height": 608
   },
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"gpt-4-1106\",\n",
    "    messages=messages,\n",
    "    tools=tools,\n",
    "    tool_choice=\"auto\", \n",
    ")\n",
    "\n",
    "response_message = response.choices[0].message\n",
    "tool_calls = response_message.tool_calls\n",
    "\n",
    "if tool_calls:\n",
    "    print (tool_calls)\n",
    "    \n",
    "    available_functions = {\n",
    "        \"get_current_weather\": get_current_weather,\n",
    "    } \n",
    "    messages.append(response_message)  \n",
    "    \n",
    "    for tool_call in tool_calls:\n",
    "        function_name = tool_call.function.name\n",
    "        function_to_call = available_functions[function_name]\n",
    "        function_args = json.loads(tool_call.function.arguments)\n",
    "        function_response = function_to_call(\n",
    "            location=function_args.get(\"location\"),\n",
    "            unit=function_args.get(\"unit\"),\n",
    "        )\n",
    "        messages.append(\n",
    "            {\n",
    "                \"tool_call_id\": tool_call.id,\n",
    "                \"role\": \"tool\",\n",
    "                \"name\": function_name,\n",
    "                \"content\": function_response,\n",
    "            }\n",
    "        )  \n",
    "    print (messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73d48df7-7eeb-4d5e-ab95-4ae5c1071ded",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "second_response = client.chat.completions.create(\n",
    "            model=\"gpt-4-1106\",\n",
    "            messages=messages,\n",
    "        )\n",
    "print (second_response)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cdcafa56",
   "metadata": {},
   "source": [
    "## 2. Using our SQL database "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3fffee20",
   "metadata": {},
   "source": [
    "**Note**: To access the data locally, use the following code:\n",
    "\n",
    "```\n",
    "os.makedirs(\"data\",exist_ok=True)\n",
    "!wget https://covidtracking.com/data/download/all-states-history.csv -P ./data/\n",
    "file_url = \"./data/all-states-history.csv\"\n",
    "df = pd.read_csv(file_url).fillna(value = 0)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28b73aa0-5eb6-419c-9637-d66a916979d9",
   "metadata": {
    "height": 81
   },
   "outputs": [],
   "source": [
    "from sqlalchemy import create_engine\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv(\"./data/all-states-history.csv\").fillna(value = 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4f045df-8811-4af5-9874-fc3f63b813c3",
   "metadata": {
    "height": 166
   },
   "outputs": [],
   "source": [
    "database_file_path = \"./db/test.db\"\n",
    "\n",
    "engine = create_engine(f'sqlite:///{database_file_path}')\n",
    "\n",
    "df.to_sql(\n",
    "    'all_states_history',\n",
    "    con=engine,\n",
    "    if_exists='replace',\n",
    "    index=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd8c6c50",
   "metadata": {},
   "source": [
    "### Create two functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d89d871e-0d85-40f1-9851-91aa953f7339",
   "metadata": {
    "height": 370
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sqlalchemy import text\n",
    "\n",
    "def get_hospitalized_increase_for_state_on_date(state_abbr, specific_date):\n",
    "    try:\n",
    "        query = f\"\"\"\n",
    "        SELECT date, hospitalizedIncrease\n",
    "        FROM all_states_history\n",
    "        WHERE state = '{state_abbr}' AND date = '{specific_date}';\n",
    "        \"\"\"\n",
    "        query = text(query)\n",
    "\n",
    "        with engine.connect() as connection:\n",
    "            result = pd.read_sql_query(query, connection)\n",
    "        if not result.empty:\n",
    "            return result.to_dict('records')[0]\n",
    "        else:\n",
    "            return np.nan\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        return np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6106ea-799d-4cbe-b75a-26a79627cfd4",
   "metadata": {
    "height": 319
   },
   "outputs": [],
   "source": [
    "def get_positive_cases_for_state_on_date(state_abbr, specific_date):\n",
    "    try:\n",
    "        query = f\"\"\"\n",
    "        SELECT date, state, positiveIncrease AS positive_cases\n",
    "        FROM all_states_history\n",
    "        WHERE state = '{state_abbr}' AND date = '{specific_date}';\n",
    "        \"\"\"\n",
    "        query = text(query)\n",
    "\n",
    "        with engine.connect() as connection:\n",
    "            result = pd.read_sql_query(query, connection)\n",
    "        if not result.empty:\n",
    "            return result.to_dict('records')[0]\n",
    "        else:\n",
    "            return np.nan\n",
    "    except Exception as e:\n",
    "        print(e)\n",
    "        return np.nan"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d235c9f0-9377-46d1-880c-c2a5da3f2bc9",
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "get_hospitalized_increase_for_state_on_date(\"AK\",\"2021-03-05\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58a111f7",
   "metadata": {},
   "source": [
    "### Execute the function calling against the SQL database"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97c99de7-2e22-467f-a5b8-771986692fc2",
   "metadata": {
    "height": 115
   },
   "outputs": [],
   "source": [
    "messages = [\n",
    "    {\"role\": \"user\",\n",
    "     \"content\": \"\"\" how many hospitalized people we had in Alaska\n",
    "                    the 2021-03-05?\"\"\"\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e22ddeb-911b-4169-810d-33a2cdbc55da",
   "metadata": {
    "height": 931
   },
   "outputs": [],
   "source": [
    "tools_sql = [\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_hospitalized_increase_for_state_on_date\",\n",
    "            \"description\": \"\"\"Retrieves the daily increase in\n",
    "                              hospitalizations for a specific state\n",
    "                              on a specific date.\"\"\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"state_abbr\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"\"\"The abbreviation of the state\n",
    "                                          (e.g., 'NY', 'CA').\"\"\"\n",
    "                    },\n",
    "                    \"specific_date\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"\"\"The specific date for\n",
    "                                          the query in 'YYYY-MM-DD'\n",
    "                                          format.\"\"\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"state_abbr\", \"specific_date\"]\n",
    "            }\n",
    "        }\n",
    "    },\n",
    "    {\n",
    "        \"type\": \"function\",\n",
    "        \"function\": {\n",
    "            \"name\": \"get_positive_cases_for_state_on_date\",\n",
    "            \"description\": \"\"\"Retrieves the daily increase in \n",
    "                              positive cases for a specific state\n",
    "                              on a specific date.\"\"\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"state_abbr\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"\"\"The abbreviation of the \n",
    "                                          state (e.g., 'NY', 'CA').\"\"\"\n",
    "                    },\n",
    "                    \"specific_date\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"\"\"The specific date for the\n",
    "                                          query in 'YYYY-MM-DD'\n",
    "                                          format.\"\"\"\n",
    "                    }\n",
    "                },\n",
    "                \"required\": [\"state_abbr\", \"specific_date\"]\n",
    "            }\n",
    "        }\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e1c0ba8-7b79-4548-9d8a-32d905b24fb9",
   "metadata": {
    "height": 625
   },
   "outputs": [],
   "source": [
    "response = client.chat.completions.create(\n",
    "    model=\"gpt-4-1106\",\n",
    "    messages=messages,\n",
    "    tools=tools_sql,\n",
    "    tool_choice=\"auto\",\n",
    ")\n",
    "\n",
    "response_message = response.choices[0].message\n",
    "tool_calls = response_message.tool_calls\n",
    "\n",
    "if tool_calls:\n",
    "    print (tool_calls)\n",
    "    \n",
    "    available_functions = {\n",
    "        \"get_positive_cases_for_state_on_date\": get_positive_cases_for_state_on_date,\n",
    "        \"get_hospitalized_increase_for_state_on_date\":get_hospitalized_increase_for_state_on_date\n",
    "    }  \n",
    "    messages.append(response_message)  \n",
    "   \n",
    "    for tool_call in tool_calls:\n",
    "        function_name = tool_call.function.name\n",
    "        function_to_call = available_functions[function_name]\n",
    "        function_args = json.loads(tool_call.function.arguments)\n",
    "        function_response = function_to_call(\n",
    "            state_abbr=function_args.get(\"state_abbr\"),\n",
    "            specific_date=function_args.get(\"specific_date\"),\n",
    "        )\n",
    "        messages.append(\n",
    "            {\n",
    "                \"tool_call_id\": tool_call.id,\n",
    "                \"role\": \"tool\",\n",
    "                \"name\": function_name,\n",
    "                \"content\": str(function_response),\n",
    "            }\n",
    "        ) \n",
    "    print(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1958ad27-0656-4208-98b9-bd813a5af0ee",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "second_response = client.chat.completions.create(\n",
    "            model=\"gpt-4-1106\",\n",
    "            messages=messages,\n",
    "        )\n",
    "print (second_response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
